{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computes LIME explanations\n",
    "Uses the AIX360 (AI Explainability 360) toolkit to compute image classification explanation heatmaps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Install the pinned dependencies with %pip so that they end up in the environment of the\n",
    "# running kernel (a bare `pip3` found on PATH may belong to a different interpreter).\n",
    "# NOTE: aix360==0.2.1 is currently disabled and therefore not installed.\n",
    "%pip install scikit-learn==0.24.1 aif360==0.3.0 tensorflow==2.4.0 nodejs==0.1.1 ipywidgets==7.6.3 lime==0.2.0.1 wget==3.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import wget\n",
    "\n",
    "# wget.download() does not overwrite an existing file; it stores a numbered copy such as\n",
    "# 'claimed_utils (1).py' instead. Only fetch the helper module when it is not present yet,\n",
    "# so repeated runs neither pile up copies nor import a different file than the one downloaded.\n",
    "if not os.path.exists('claimed_utils.py'):\n",
    "    wget.download(\n",
    "        'https://raw.githubusercontent.com/'\n",
    "        'elyra-ai/component-library/master/claimed_utils.py'\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from claimed_utils import unzip\n",
    "import os.path\n",
    "import glob\n",
    "from lime import lime_image\n",
    "from lime.wrappers.scikit_image import SegmentationAlgorithm\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import ImageGrid\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @dependency codait_utils.ipynb\n",
    "# @dependency model.zip\n",
    "# @dependency data.zip\n",
    "# @param model zip file name\n",
    "# @param data zip file name\n",
    "# @returns random heatmap image\n",
    "# objects inline in jupyter (currently 4,TODO make specifiable)\n",
    "# (next version will return a zip folder with images+heatmaps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Archive names can be overridden through environment variables of the same name;\n",
    "# otherwise the defaults 'model.zip' and 'data.zip' are used.\n",
    "model_zip = os.getenv('model_zip', 'model.zip')\n",
    "data_zip = os.getenv('data_zip', 'data.zip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 55.657074,
     "end_time": "2021-01-22T16:50:23.433635",
     "exception": false,
     "start_time": "2021-01-22T16:49:27.776561",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# !jupyter labextension install @jupyter-widgets/jupyterlab-manager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the model and the data archive with the helper imported from claimed_utils.\n",
    "# NOTE(review): '.' is presumably the extraction target (current directory) - confirm\n",
    "# against the signature of claimed_utils.unzip.\n",
    "unzip('.', model_zip)\n",
    "unzip('.', data_zip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 6.978567,
     "end_time": "2021-01-22T16:50:32.250261",
     "exception": false,
     "start_time": "2021-01-22T16:50:25.271694",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the Keras model stored under the local path 'model'\n",
    "# (presumably created by unpacking model_zip - verify against the archive layout).\n",
    "model = keras.models.load_model('model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.322614,
     "end_time": "2021-01-22T16:50:32.686989",
     "exception": false,
     "start_time": "2021-01-22T16:50:32.364375",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "img_height = 400\n",
    "img_width = 400\n",
    "input_shape = (img_width, img_height)\n",
    "\n",
    "# Options shared by both splits: the same seed and validation_split must be passed to\n",
    "# both calls so that they agree on how the images are divided.\n",
    "dataset_options = dict(\n",
    "    validation_split=0.2,\n",
    "    seed=123,\n",
    "    image_size=(img_height, img_width),\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "\n",
    "train_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
    "    'data', subset='training', **dataset_options)\n",
    "\n",
    "val_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
    "    'data', subset='validation', **dataset_options)\n",
    "\n",
    "# Entries found directly under 'data' (informational only).\n",
    "folder = glob.glob(\"data/*\")\n",
    "\n",
    "# Take the class count from the classes Keras discovered rather than from the glob result:\n",
    "# stray files inside 'data' would otherwise inflate the one-hot depth.\n",
    "# class_names has to be read before .map(), the mapped dataset does not carry the attribute.\n",
    "num_classes = len(train_ds.class_names)\n",
    "\n",
    "# Turn the integer class indices into one-hot vectors.\n",
    "train_ds = train_ds.map(lambda x, y: (x, tf.one_hot(y, depth=num_classes)))\n",
    "val_ds = val_ds.map(lambda x, y: (x, tf.one_hot(y, depth=num_classes)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.832315,
     "end_time": "2021-01-22T16:50:40.691850",
     "exception": false,
     "start_time": "2021-01-22T16:50:39.859535",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# LIME image explainer instance; used by get_explaination_as_mask() to explain model.predict.\n",
    "explainer = lime_image.LimeImageExplainer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: pull one batch from the validation split and decode the label of its\n",
    "# first sample from the one-hot vector back to the class index.\n",
    "image, label = next(iter(val_ds))\n",
    "label = tf.math.argmax(label[:1], axis=1).numpy()[0]  # de-onehotencode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.162176,
     "end_time": "2021-01-22T16:50:41.017556",
     "exception": false,
     "start_time": "2021-01-22T16:50:40.855380",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_random_image():\n",
    "    \"\"\"Fetch one validation sample together with its true and its predicted class.\n",
    "\n",
    "    The first image of the next batch yielded by ``val_ds`` is used, so whether\n",
    "    consecutive calls return different images depends on how ``val_ds`` is shuffled.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(image, label, prediction)`` where ``image`` is an integer array of\n",
    "        shape ``(img_height, img_width, 3)``, ``label`` is the class index decoded\n",
    "        from the one-hot label and ``prediction`` is the index of the highest\n",
    "        model output for that image.\n",
    "\n",
    "    Raises:\n",
    "        StopIteration: if ``val_ds`` does not yield any batch.\n",
    "    \"\"\"\n",
    "    images, labels = next(iter(val_ds))\n",
    "    # Convert only the sample that is actually used instead of the whole batch.\n",
    "    image = images[0].numpy().astype(int)\n",
    "    image = image.reshape(img_height, img_width, 3)\n",
    "    # Keep a leading axis of size one so that argmax runs along the class axis.\n",
    "    label = tf.math.argmax(labels[:1], axis=1).numpy()[0]  # de-onehotencode\n",
    "    # The model expects a batch, hence the extra leading axis.\n",
    "    prediction = model.predict(image[np.newaxis, ...])\n",
    "    prediction = tf.math.argmax(prediction, 1).numpy()[0]  # de-onehotencode\n",
    "    return image, label, prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.123489,
     "end_time": "2021-01-22T16:50:41.256283",
     "exception": false,
     "start_time": "2021-01-22T16:50:41.132794",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_explaination_as_mask(image, label):\n",
    "    \"\"\"Explain the model output for ``label`` on ``image`` and return the LIME mask.\n",
    "\n",
    "    Args:\n",
    "        image: image array handed to ``explainer.explain_instance``.\n",
    "        label: class index for which the explanation is requested.\n",
    "\n",
    "    Returns:\n",
    "        The mask, i.e. the second element returned by\n",
    "        ``explanation.get_image_and_mask(label)``.\n",
    "    \"\"\"\n",
    "    segmentation_fn = SegmentationAlgorithm(algo_type='slic')\n",
    "    # By default LIME only explains the top-5 predicted classes, and get_image_and_mask()\n",
    "    # raises KeyError for any other label. Request exactly `label` instead\n",
    "    # (top_labels=None makes LIME honour the `labels` argument).\n",
    "    explanation = explainer.explain_instance(\n",
    "        image,\n",
    "        model.predict,\n",
    "        labels=(label,),\n",
    "        top_labels=None,\n",
    "        segmentation_fn=segmentation_fn\n",
    "    )\n",
    "    print(label)\n",
    "    return explanation.get_image_and_mask(label)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_image_with_mask(image, mask):\n",
    "    \"\"\"Overlay ``mask`` on ``image`` by raising the masked pixels to full intensity.\n",
    "\n",
    "    Args:\n",
    "        image: array of shape ``(height, width, channels)``.\n",
    "        mask: array of shape ``(height, width)``; it is cast to int and scaled by 255.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray: ``uint8`` array shaped like ``image`` that holds, for every\n",
    "        channel, the element-wise maximum of the image and the scaled mask.\n",
    "    \"\"\"\n",
    "    np_mask = np.array(mask.astype(int) * 255, np.uint8)\n",
    "    # Broadcast the 2-D mask over the channel axis; this covers every channel of the\n",
    "    # image in one step and never leaves part of the result uninitialised.\n",
    "    return np.maximum(image, np_mask[:, :, np.newaxis]).astype(np.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw a 2x2 grid; every cell shows one validation image with its LIME mask blended in.\n",
    "fig = plt.figure(figsize=(12., 12.))\n",
    "grid = ImageGrid(fig, 111,  # similar to subplot(111)\n",
    "                 nrows_ncols=(2, 2),\n",
    "                 axes_pad=0.1,  # pad between axes in inch.\n",
    "                 )\n",
    "\n",
    "for ax in grid:\n",
    "    try:\n",
    "        image, label, prediction = get_random_image()\n",
    "        mask = get_explaination_as_mask(image, label)\n",
    "        masked_image = merge_image_with_mask(image, mask)\n",
    "        ax.imshow(masked_image)\n",
    "    except KeyError as missing_label:\n",
    "        # Raised when LIME holds no explanation for the requested label; the axis is\n",
    "        # left empty and the missing label is reported instead of aborting the grid.\n",
    "        print(f'Model label, LIME label mismatch: {missing_label}')\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 106.108733,
   "end_time": "2021-01-22T16:51:10.430569",
   "environment_variables": {},
   "exception": null,
   "input_path": "/home/jovyan/work/elyra-classification/lime.ipynb",
   "output_path": "/home/jovyan/work/elyra-classification/lime.ipynb",
   "parameters": {},
   "start_time": "2021-01-22T16:49:24.321836",
   "version": "2.2.2"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "04815b6c174c4911aeb44205917f2179": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "05c301d3af534851912aca10a400c2ef": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "065e9a24ba574c8388e8b34415d3a177": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_efb39be96ba348d0b85cc6a844d3bbce",
        "IPY_MODEL_52dc9534522c41ee948e351732c1224b",
        "IPY_MODEL_30a8a9e345d048a38926dbcbb30c102f"
       ],
       "layout": "IPY_MODEL_60ca3839d1af4bf8b6b8cc643309f863"
      }
     },
     "30a8a9e345d048a38926dbcbb30c102f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_tooltip": null,
       "layout": "IPY_MODEL_78695ef1716a4264b558c466a6150765",
       "placeholder": "​",
       "style": "IPY_MODEL_8f81e34e8cd546c4b7367dea98a5893f",
       "value": " 1000/1000 [00:24&lt;00:00, 40.16it/s]"
      }
     },
     "52dc9534522c41ee948e351732c1224b": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_tooltip": null,
       "layout": "IPY_MODEL_05c301d3af534851912aca10a400c2ef",
       "max": 1000,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_04815b6c174c4911aeb44205917f2179",
       "value": 1000
      }
     },
     "60ca3839d1af4bf8b6b8cc643309f863": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "78695ef1716a4264b558c466a6150765": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "8e2b113a9f414b6395320be8b1a0a5bd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "DescriptionStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "StyleView",
       "description_width": ""
      }
     },
     "8f81e34e8cd546c4b7367dea98a5893f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "DescriptionStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "StyleView",
       "description_width": ""
      }
     },
     "bf9a01df5cbb437d962131258e6ec019": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "efb39be96ba348d0b85cc6a844d3bbce": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_tooltip": null,
       "layout": "IPY_MODEL_bf9a01df5cbb437d962131258e6ec019",
       "placeholder": "​",
       "style": "IPY_MODEL_8e2b113a9f414b6395320be8b1a0a5bd",
       "value": "100%"
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
